﻿using Microsoft.EntityFrameworkCore;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Linq.Dynamic.Core;
using Z.BulkOperations;

namespace EntityFrameworkCoreUtils
{
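    /// <summary>
    /// Base repository implementation backed by Entity Framework Core.
    /// See also (SQL Server provider function mappings):
    /// https://docs.microsoft.com/zh-cn/ef/core/providers/sql-server/functions
    /// </summary>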
    public partial class EfRepository<TEntity> : IRepository<TEntity> where TEntity : class
    {
        #region 参数
        // Database context abstraction supplied through the constructor.
        private readonly IDbContext _context;
        // Cached set for TEntity; assigned on first access of the Entities property.
        private DbSet<TEntity> _entities;
        #endregion

        #region 构造函数
        /// <summary>
        /// Creates a repository that works against the given database context.
        /// </summary>
        /// <param name="context">Database context used for all queries and writes.</param>
        /// <exception cref="ArgumentNullException"><paramref name="context"/> is null.</exception>
        public EfRepository(IDbContext context)
        {
            // Fail fast here instead of with a NullReferenceException on first use.
            _context = context ?? throw new ArgumentNullException(nameof(context));
        }
        #endregion

        #region 公共方法
        /// <summary>
        /// Reverts the pending changes tracked by the context after a failed save and
        /// returns the full text of the exception.
        /// </summary>
        /// <param name="exception">The exception raised by the failed save.</param>
        /// <returns>The result of calling <c>ToString()</c> on <paramref name="exception"/>.</returns>
        protected string GetFullErrorTextAndRollbackEntityChanges(DbUpdateException exception)
        {
            if (_context is DbContext dbContext)
            {
                // Materialize the list before touching states so the tracker is not
                // enumerated while its entries are being changed.
                var pendingEntries = dbContext.ChangeTracker.Entries()
                    .Where(e => e.State == EntityState.Added
                             || e.State == EntityState.Modified
                             || e.State == EntityState.Deleted)
                    .ToList();

                foreach (var entry in pendingEntries)
                {
                    // A failed insert must not stay tracked as if the row existed, so it is
                    // detached. Failed updates and deletes go back to Unchanged; leaving a
                    // Deleted entry in place would make the SaveChanges call below replay
                    // the failing delete and throw from inside the caller's catch block.
                    entry.State = entry.State == EntityState.Added
                        ? EntityState.Detached
                        : EntityState.Unchanged;
                }
            }

            _context.SaveChanges();
            return exception.ToString();
        }
        #endregion

        #region 方法
        /// <summary>
        /// Returns the first entity matching the predicate, optionally ordered, without change tracking.
        /// </summary>
        /// <param name="OrderBy">
        /// Ordering string for the string-based OrderBy, e.g. "id asc".
        /// Null, empty or whitespace means no ordering is applied.
        /// </param>
        /// <param name="predicate">Filter expression.</param>
        /// <returns>The first matching entity, or the default value (null) when nothing matches.</returns>
        public virtual TEntity GetOne(string OrderBy, Expression<Func<TEntity, bool>> predicate)
        {
            IQueryable<TEntity> query = Entities.Where(predicate);

            // IsNullOrWhiteSpace also covers a null OrderBy, which Trim() cannot handle.
            if (!string.IsNullOrWhiteSpace(OrderBy))
            {
                query = query.OrderBy(OrderBy);
            }

            return query.AsNoTracking().FirstOrDefault();
        }

        /// <summary>
        /// Counts the entities that satisfy a predicate.
        /// </summary>
        /// <param name="predicate">Filter expression.</param>
        /// <returns>Number of matching records.</returns>
        public virtual int Count(Expression<Func<TEntity, bool>> predicate)
        {
            IQueryable<TEntity> matching = Entities.Where(predicate);
            int total = matching.AsNoTracking().Count();
            return total;
        }

        /// <summary>
        /// Counts the entities matching a string predicate that takes no parameters.
        /// </summary>
        /// <param name="ExpressionsSql">Predicate string handed to the string-based Where.</param>
        /// <returns>Number of matching records.</returns>
        public virtual int Count(string ExpressionsSql) => Count(ExpressionsSql, null);

        /// <summary>
        /// Counts the entities matching a string predicate, with optional parameter values.
        /// </summary>
        /// <param name="ExpressionsSql">Predicate string; parameters are referenced as @0, @1, ...</param>
        /// <param name="ExpressionPartams">Values for the placeholders; may be null or empty.</param>
        /// <returns>Number of matching records.</returns>
        public virtual int Count(string ExpressionsSql, params object[] ExpressionPartams)
        {
            bool hasArguments = ExpressionPartams != null && ExpressionPartams.Length > 0;

            // Only forward the argument array when it actually carries values.
            IQueryable<TEntity> matching = hasArguments
                ? Entities.Where(ExpressionsSql, ExpressionPartams)
                : Entities.Where(ExpressionsSql);

            return matching.AsNoTracking().Count();
        }

        /// <summary>
        /// Tells whether at least one entity satisfies the predicate.
        /// </summary>
        /// <param name="predicate">Filter expression.</param>
        /// <returns>True when a matching entity exists; otherwise false.</returns>
        public virtual bool IsExists(Expression<Func<TEntity, bool>> predicate) => Entities.Any(predicate);

        /// <summary>
        /// Returns every entity that satisfies all of the given predicates, optionally ordered,
        /// without change tracking.
        /// </summary>
        /// <param name="OrderBy">
        /// Ordering string for the string-based OrderBy, e.g. "id asc".
        /// Null, empty or whitespace means no ordering is applied.
        /// </param>
        /// <param name="expressions">Predicates applied one after another (logical AND); null means no filtering.</param>
        /// <returns>The matching entities.</returns>
        public virtual List<TEntity> GetList(string OrderBy, List<Expression<Func<TEntity, bool>>> expressions)
        {
            IQueryable<TEntity> query = Entities;

            if (expressions != null)
            {
                foreach (var filter in expressions)
                {
                    query = query.Where(filter);
                }
            }

            // The string-based OrderBy rejects an empty ordering, so only sort
            // when the caller actually supplied one.
            if (!string.IsNullOrWhiteSpace(OrderBy))
            {
                query = query.OrderBy(OrderBy);
            }

            return query.AsNoTracking().ToList();
        }


        /// <summary>
        /// Lists entities filtered by a string predicate that takes no parameters.
        /// </summary>
        /// <param name="OrderBy">Ordering string, e.g. "id asc".</param>
        /// <param name="Expressions">Predicate string, e.g. "id=123".</param>
        /// <returns>The matching entities.</returns>
        public virtual List<TEntity> GetList(string OrderBy, string Expressions) => GetList(OrderBy, Expressions, null);
        /// <summary>
        /// Lists entities filtered by a string predicate with optional parameter values,
        /// optionally ordered, without change tracking.
        /// </summary>
        /// <param name="OrderBy">
        /// Ordering string for the string-based OrderBy, e.g. "id asc".
        /// Null, empty or whitespace means no ordering is applied.
        /// </param>
        /// <param name="Expressions">Predicate string, e.g. "id=@0".</param>
        /// <param name="ExpressionPartams">Values for the placeholders, e.g. 125; may be null or empty.</param>
        /// <returns>The matching entities.</returns>
        public virtual List<TEntity> GetList(string OrderBy, string Expressions, params object[] ExpressionPartams)
        {
            bool hasArguments = ExpressionPartams != null && ExpressionPartams.Length > 0;

            // Only forward the argument array when it actually carries values.
            IQueryable<TEntity> query = hasArguments
                ? Entities.Where(Expressions, ExpressionPartams)
                : Entities.Where(Expressions);

            // The string-based OrderBy rejects an empty ordering, so only sort
            // when the caller actually supplied one.
            if (!string.IsNullOrWhiteSpace(OrderBy))
            {
                query = query.OrderBy(OrderBy);
            }

            return query.AsNoTracking().ToList();
        }

        /// <summary>
        /// Returns one page of the entities that satisfy all of the given predicates.
        /// </summary>
        /// <param name="iPage">Page number handed to PageResult.</param>
        /// <param name="PageSize">Number of records per page.</param>
        /// <param name="OrderBy">
        /// Ordering string for the string-based OrderBy, e.g. "id asc".
        /// Null, empty or whitespace means no ordering is applied.
        /// </param>
        /// <param name="expressions">Predicates applied one after another (logical AND); null means no filtering.</param>
        /// <returns>The page contents together with the paging values reported by PageResult.</returns>
        public virtual PageOf<TEntity> GetPageList(int iPage, int PageSize, string OrderBy, List<Expression<Func<TEntity, bool>>> expressions)
        {
            // NOTE(review): unlike the other read methods in this class, this query is not
            // marked AsNoTracking. Left unchanged in case callers rely on tracked results — confirm.
            IQueryable<TEntity> query = Entities;

            if (expressions != null)
            {
                foreach (var filter in expressions)
                {
                    query = query.Where(filter);
                }
            }

            // Filter first, then sort; skip sorting when no ordering was supplied.
            if (!string.IsNullOrWhiteSpace(OrderBy))
            {
                query = query.OrderBy(OrderBy);
            }

            var page = query.PageResult(iPage, PageSize);
            return new PageOf<TEntity>()
            {
                list = page.Queryable.ToList(),
                page_index = page.CurrentPage,
                page_size = page.PageSize,
                total = page.RowCount
            };
        }


        /// <summary>
        /// Returns one page of entities filtered by a string predicate that takes no parameters.
        /// </summary>
        /// <param name="iPage">Page number.</param>
        /// <param name="PageSize">Number of records per page.</param>
        /// <param name="OrderBy">Ordering string, e.g. "id asc".</param>
        /// <param name="Expressions">Predicate string, e.g. "id=123".</param>
        /// <returns>The requested page.</returns>
        public virtual PageOf<TEntity> GetPageList(int iPage, int PageSize, string OrderBy, string Expressions)
            => GetPageList(iPage, PageSize, OrderBy, Expressions, null);
        /// <summary>
        /// Returns one page of entities filtered by a string predicate with optional
        /// parameter values, without change tracking.
        /// </summary>
        /// <param name="iPage">Page number handed to PageResult.</param>
        /// <param name="PageSize">Number of records per page.</param>
        /// <param name="OrderBy">
        /// Ordering string for the string-based OrderBy, e.g. "id asc".
        /// Null, empty or whitespace means no ordering is applied.
        /// </param>
        /// <param name="Expressions">Predicate string, e.g. "id=@0".</param>
        /// <param name="ExpressionPartams">Values for the placeholders, e.g. 125; may be null or empty.</param>
        /// <returns>The page contents together with the paging values reported by PageResult.</returns>
        public virtual PageOf<TEntity> GetPageList(int iPage, int PageSize, string OrderBy, string Expressions, params object[] ExpressionPartams)
        {
            bool hasArguments = ExpressionPartams != null && ExpressionPartams.Length > 0;

            // Only forward the argument array when it actually carries values, matching
            // the other string-predicate methods of this class (a null array is never passed on).
            IQueryable<TEntity> query = hasArguments
                ? Entities.Where(Expressions, ExpressionPartams)
                : Entities.Where(Expressions);

            // Filter first, then sort; skip sorting when no ordering was supplied.
            if (!string.IsNullOrWhiteSpace(OrderBy))
            {
                query = query.OrderBy(OrderBy);
            }

            var page = query.AsNoTracking().PageResult(iPage, PageSize);
            return new PageOf<TEntity>()
            {
                list = page.Queryable.ToList(),
                page_index = page.CurrentPage,
                page_size = page.PageSize,
                total = page.RowCount
            };
        }
        #region 删除
        /// <summary>
        /// Deletes the entities that satisfy a predicate through DeleteFromQuery.
        /// </summary>
        /// <param name="predicate">Condition selecting the entities to delete.</param>
        /// <returns>The value returned by DeleteFromQuery (presumably the number of affected rows — confirm against the library docs).</returns>
        public virtual int Delete(Expression<Func<TEntity, bool>> predicate)
        {
            var targets = Entities.Where(predicate);
            return targets.DeleteFromQuery();
        }

        /// <summary>
        /// Removes a single entity and saves the change.
        /// </summary>
        /// <param name="entity">Entity to delete.</param>
        /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
        /// <exception cref="Exception">The save raised a DbUpdateException; it is attached as the inner exception.</exception>
        public virtual void Delete(TEntity entity)
        {
            if (entity == null)
            {
                throw new ArgumentNullException(nameof(entity));
            }

            try
            {
                Entities.Remove(entity);
                _context.SaveChanges();
            }
            catch (DbUpdateException saveError)
            {
                // Roll back tracked changes and carry the full error text in the message.
                string details = GetFullErrorTextAndRollbackEntityChanges(saveError);
                throw new Exception(details, saveError);
            }
        }

        /// <summary>
        /// Deletes a collection of entities through the BulkDelete extension.
        /// </summary>
        /// <param name="entities">Entities to delete.</param>
        /// <exception cref="ArgumentNullException"><paramref name="entities"/> is null.</exception>
        /// <exception cref="Exception">A DbUpdateException was raised; it is attached as the inner exception.</exception>
        public virtual void Delete(IEnumerable<TEntity> entities)
        {
            if (entities == null)
            {
                throw new ArgumentNullException(nameof(entities));
            }

            try
            {
                // No SaveChanges call follows the bulk call.
                Entities.BulkDelete(entities);
            }
            catch (DbUpdateException bulkError)
            {
                // Roll back tracked changes and carry the full error text in the message.
                string details = GetFullErrorTextAndRollbackEntityChanges(bulkError);
                throw new Exception(details, bulkError);
            }
        }
        #endregion


        #region 添加
        /// <summary>
        /// Adds a single entity and saves the change.
        /// </summary>
        /// <param name="entity">Entity to insert.</param>
        /// <returns>The entity instance exposed by the entry that Add returned.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
        /// <exception cref="Exception">The save raised a DbUpdateException; it is attached as the inner exception.</exception>
        public virtual TEntity Insert(TEntity entity)
        {
            if (entity == null)
            {
                throw new ArgumentNullException(nameof(entity));
            }

            try
            {
                var entry = Entities.Add(entity);
                _context.SaveChanges();
                return entry.Entity;
            }
            catch (DbUpdateException saveError)
            {
                // Roll back tracked changes and carry the full error text in the message.
                string details = GetFullErrorTextAndRollbackEntityChanges(saveError);
                throw new Exception(details, saveError);
            }
        }



        /// <summary>
        /// Inserts a collection of entities through the BulkInsert extension.
        /// </summary>
        /// <param name="entities">Entities to insert.</param>
        /// <exception cref="ArgumentNullException"><paramref name="entities"/> is null.</exception>
        /// <exception cref="Exception">A DbUpdateException was raised; it is attached as the inner exception.</exception>
        public virtual void BulkInsert(IEnumerable<TEntity> entities)
        {
            if (entities == null)
            {
                throw new ArgumentNullException(nameof(entities));
            }

            try
            {
                // No SaveChanges call follows the bulk call.
                Entities.BulkInsert(entities);
            }
            catch (DbUpdateException bulkError)
            {
                // Roll back tracked changes and carry the full error text in the message.
                string details = GetFullErrorTextAndRollbackEntityChanges(bulkError);
                throw new Exception(details, bulkError);
            }
        }
        #endregion

        #region 修改
        /// <summary>
        /// Marks a single entity as updated and saves the change.
        /// </summary>
        /// <param name="entity">Entity to update.</param>
        /// <returns>The entity instance exposed by the entry that Update returned.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
        /// <exception cref="Exception">The save raised a DbUpdateException; it is attached as the inner exception.</exception>
        public virtual TEntity Update(TEntity entity)
        {
            if (entity == null)
            {
                throw new ArgumentNullException(nameof(entity));
            }

            try
            {
                var entry = Entities.Update(entity);
                _context.SaveChanges();
                return entry.Entity;
            }
            catch (DbUpdateException saveError)
            {
                // Roll back tracked changes and carry the full error text in the message.
                string details = GetFullErrorTextAndRollbackEntityChanges(saveError);
                throw new Exception(details, saveError);
            }
        }


     

        /// <summary>
        /// Updates a collection of entities through the BulkUpdate extension.
        /// </summary>
        /// <param name="entities">Entities to update.</param>
        /// <exception cref="ArgumentNullException"><paramref name="entities"/> is null.</exception>
        /// <exception cref="Exception">A DbUpdateException was raised; it is attached as the inner exception.</exception>
        public virtual void Update(IEnumerable<TEntity> entities)
        {
            if (entities == null)
            {
                throw new ArgumentNullException(nameof(entities));
            }

            try
            {
                // No SaveChanges call follows the bulk call.
                Entities.BulkUpdate(entities);
            }
            catch (DbUpdateException bulkError)
            {
                // Roll back tracked changes and carry the full error text in the message.
                string details = GetFullErrorTextAndRollbackEntityChanges(bulkError);
                throw new Exception(details, bulkError);
            }
        }
        /// <summary>
        /// Updates the entities that satisfy a predicate through UpdateFromQuery.
        /// </summary>
        /// <param name="predicate">Condition selecting the entities to update.</param>
        /// <param name="updator">Expression describing the new values.</param>
        /// <returns>The value returned by UpdateFromQuery (presumably the number of affected rows — confirm against the library docs).</returns>
        /// <exception cref="Exception">A DbUpdateException was raised; it is attached as the inner exception.</exception>
        public virtual int Update(Expression<Func<TEntity, bool>> predicate, Expression<Func<TEntity, TEntity>> updator)
        {
            try
            {
                var targets = Entities.Where(predicate);
                return targets.UpdateFromQuery(updator);
            }
            catch (DbUpdateException updateError)
            {
                // Roll back tracked changes and carry the full error text in the message.
                string details = GetFullErrorTextAndRollbackEntityChanges(updateError);
                throw new Exception(details, updateError);
            }
        }

        /// <summary>
        /// Forwards a collection of entities to the BulkMerge extension
        /// (insert-or-update according to the bulk library's documentation).
        /// </summary>
        /// <param name="entities">Entities to merge.</param>
        /// <param name="options">Callback that configures the bulk operation.</param>
        /// <exception cref="ArgumentNullException"><paramref name="entities"/> is null.</exception>
        /// <exception cref="Exception">A DbUpdateException was raised; it is attached as the inner exception.</exception>
        public void BulkMerge(IEnumerable<TEntity> entities, Action<BulkOperation<TEntity>> options)
        {
            // Same guard as the other bulk methods of this class.
            if (entities == null)
            {
                throw new ArgumentNullException(nameof(entities));
            }

            try
            {
                Entities.BulkMerge(entities, options);
            }
            catch (DbUpdateException exception)
            {
                // Roll back tracked changes and carry the full error text in the message.
                throw new Exception(GetFullErrorTextAndRollbackEntityChanges(exception), exception);
            }
        }
        /// <summary>
        /// Forwards a collection of entities to the BulkSynchronize extension
        /// (per the bulk library's documentation, makes the table mirror the given collection).
        /// </summary>
        /// <param name="entities">Entities to synchronize.</param>
        /// <param name="options">Callback that configures the bulk operation.</param>
        /// <exception cref="ArgumentNullException"><paramref name="entities"/> is null.</exception>
        /// <exception cref="Exception">A DbUpdateException was raised; it is attached as the inner exception.</exception>
        public void BulkSynchronize(IEnumerable<TEntity> entities, Action<BulkOperation<TEntity>> options)
        {
            // Same guard as the other bulk methods of this class.
            if (entities == null)
            {
                throw new ArgumentNullException(nameof(entities));
            }

            try
            {
                Entities.BulkSynchronize(entities, options);
            }
            catch (DbUpdateException exception)
            {
                // Roll back tracked changes and carry the full error text in the message.
                throw new Exception(GetFullErrorTextAndRollbackEntityChanges(exception), exception);
            }
        }

        
        /// <summary>
        /// Exposes the context's set for another entity type as a queryable.
        /// </summary>
        /// <typeparam name="TOther">Entity type to query.</typeparam>
        /// <returns>The set obtained from the context for <typeparamref name="TOther"/>.</returns>
        public IQueryable<TOther> OtherTable<TOther>() where TOther : class => _context.Set<TOther>();
        #endregion


        #endregion

        #region 属性
        /// <summary>
        /// Gets the entity set as a queryable.
        /// </summary>
        public virtual IQueryable<TEntity> Table
        {
            get { return Entities; }
        }

        /// <summary>
        /// Gets the entity set with the EF "no tracking" option applied; intended for
        /// records that are loaded for read-only use.
        /// </summary>
        public virtual IQueryable<TEntity> TableNoTracking
        {
            get { return Entities.AsNoTracking(); }
        }

        /// <summary>
        /// Gets the set for TEntity, fetching it from the context on first access
        /// and reusing the cached instance afterwards.
        /// </summary>
        public virtual DbSet<TEntity> Entities
        {
            get { return _entities ?? (_entities = _context.Set<TEntity>()); }
        }
        #endregion
    }
}
